package com.ihome.framework.core.mq;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.ihome.framework.core.log.Log;
import com.ihome.framework.core.log.LogOp;
import com.ihome.framework.core.utils.CoreUtils;
import com.netease.cloud.nqs.client.ClientConfig;
import com.netease.cloud.nqs.client.SimpleMessageSessionFactory;
import com.netease.cloud.nqs.client.consumer.ConsumerConfig;
import com.netease.cloud.nqs.client.consumer.MessageConsumer;
import com.netease.cloud.nqs.client.consumer.MessageHandler;
import com.netease.cloud.nqs.client.exception.MessageClientException;
import com.rabbitmq.client.ShutdownSignalException;

/**
 * Simple message-queue consumer wrapper (built on the NetEase NQS client) that lets one queue be consumed
 * by several consumer threads.
 * <p>
 * Every call to {@link #consumeMessage(String, MessageHandler)} opens a new {@link MessageConsumer} from a
 * shared {@link SimpleMessageSessionFactory} and starts a dedicated thread that consumes from it. When
 * consumption fails with a {@link ShutdownSignalException} or {@link MessageClientException}, the thread
 * closes its consumer and retries every {@code RE_CONN_INTERVAL} ms to open a replacement under the same
 * consumer id, until it succeeds or {@link #destory()} is called.
 *
 * @author zhengxiaohong
 */
public class MQMultiThreadConsumer implements InitializingBean {

    private MQConfig mqConfig;

    private MQConsumerConfig mqConsumerConfig;

    /** Factory shared by every consumer this instance creates. */
    private SimpleMessageSessionFactory sessionFactory;

    /** consumerId -> currently open consumer. The entry is absent while its runner is reconnecting. */
    private Map<String, MessageConsumer> consumerMap;

    /** consumerId -> runner (thread body) that owns that consumer id. */
    private Map<String, MQConsumerRunner> consumerRunnerMap;

    /** queueName -> sequence used to build consumer ids of the form {@code queueName-N}. */
    private Map<String, AtomicInteger> counterMap;

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** Pause between reconnect attempts, in milliseconds (passed to {@code CoreUtils.sleep}). */
    private final static long RE_CONN_INTERVAL = 3000;

    /**
     * Default maximum number of consumer threads.
     */
    private final static int DEFAULT_MAX_THREAD = 200;

    /**
     * Maximum number of consumer threads, guarding against accidentally starting too many.
     */
    private int maxThread = DEFAULT_MAX_THREAD;

    private AtomicBoolean isInit = new AtomicBoolean(false);

    /**
     * Creates a consumer wrapper. Nothing is connected until {@link #init()} (or
     * {@link #afterPropertiesSet()}) is called.
     *
     * @param mqConfig
     *            MQ connection configuration
     * @param mqConsumerConfig
     *            consumer configuration
     */
    public MQMultiThreadConsumer(MQConfig mqConfig, MQConsumerConfig mqConsumerConfig) {
        this.mqConfig = mqConfig;
        this.mqConsumerConfig = mqConsumerConfig;
    }

    /**
     * Spring lifecycle hook; delegates to {@link #init()}.
     *
     * @see org.springframework.beans.factory.InitializingBean#afterPropertiesSet()
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        init();
    }

    /**
     * Initializes the bookkeeping maps and the shared session factory. Only the first successful call
     * takes effect; later calls are no-ops.
     *
     * @throws MQException
     *             if {@code mqConfig} or {@code mqConsumerConfig} is null
     */
    public void init() {
        // Validate before flipping the flag, so a failed call does not leave the instance
        // permanently marked as initialized with null maps.
        if (mqConfig == null) {
            throw new MQException("mqConfig can not be null");
        }
        if (mqConsumerConfig == null) {
            throw new MQException("mqConsumerConfig can not be null");
        }
        if (isInit.compareAndSet(false, true)) {
            consumerMap = new ConcurrentHashMap<String, MessageConsumer>();
            consumerRunnerMap = new ConcurrentHashMap<String, MQConsumerRunner>();
            counterMap = new ConcurrentHashMap<String, AtomicInteger>();
            init(mqConfig, mqConsumerConfig);
        }
    }

    /**
     * Creates the shared session factory from {@code mqConfig.getClientConfig()}.
     *
     * @param mqConfig
     *            MQ connection configuration
     * @param mqConsumerConfig
     *            consumer configuration (currently unused here)
     */
    private synchronized void init(MQConfig mqConfig, MQConsumerConfig mqConsumerConfig) {
        ClientConfig cc = mqConfig.getClientConfig();
        sessionFactory = new SimpleMessageSessionFactory(cc);
    }

    /**
     * Builds the consumer configuration used for every consumer of {@code queueName}.
     *
     * @param queueName
     *            queue to consume from
     * @return a fresh consumer configuration
     */
    private ConsumerConfig buildConsumerConfig(String queueName) {
        ConsumerConfig config = new ConsumerConfig();
        // NOTE(review): the consumer group is hard-coded to "test"; confirm this is intended outside testing.
        config.setGroup("test");
        config.setProductId(mqConfig.getExchange());
        config.setPrefetchCount(mqConsumerConfig.getPrefetchCount());
        config.setRequireAck(true);
        config.setQueueName(queueName);
        return config;
    }

    /**
     * Opens a new consumer on {@code queueName} from the shared factory.
     * <p>
     * A failure is logged and reported as {@code null}. The shared factory is deliberately left open,
     * because other consumers and later reconnect attempts depend on it.
     *
     * @param queueName
     *            queue to consume from
     * @return the new consumer, or {@code null} if creation failed
     */
    private MessageConsumer createMessageConsumer(String queueName) {
        try {
            return sessionFactory.createConsumer(buildConsumerConfig(queueName));
        } catch (MessageClientException e) {
            logger.warn(Log.op(LogOp.MQ_CONSUMER_INIT).msg("create consumer fail").kv("queue", queueName)
                    .toString(), e);
            return null;
        }
    }

    /**
     * Opens a consumer connection for {@code queueName} and assigns it a new consumer id.
     *
     * @param queueName
     *            queue to consume from
     * @return the new consumer id ({@code queueName-N}), or {@code null} if the thread limit is reached or
     *         the consumer could not be created
     */
    private synchronized String initConsumer(String queueName) {
        // Enforce the thread limit; each id created here becomes a new runner thread.
        int threadNum = consumerRunnerMap.size();
        if (threadNum >= maxThread) {
            logger.warn(Log.op(LogOp.MQ_CONSUMER_INIT).msg("exceed max thread").kv("maxThread", maxThread)
                    .kv("threadNum", threadNum).toString());
            return null;
        }
        MessageConsumer messageConsumer = createMessageConsumer(queueName);
        if (messageConsumer == null) {
            return null;
        }
        // The per-queue counter is only advanced once a consumer has actually been created.
        AtomicInteger counter = counterMap.get(queueName);
        if (counter == null) {
            counter = new AtomicInteger();
            counterMap.put(queueName, counter);
        }
        String consumerId = queueName + "-" + counter.incrementAndGet();
        consumerMap.put(consumerId, messageConsumer);
        logger.info(Log.op(LogOp.MQ_CONSUMER_INIT).kv("queue", queueName).toString());
        return consumerId;
    }

    /**
     * Opens a consumer connection for {@code queueName} under an existing consumer id. Used when a runner
     * reconnects.
     * <p>
     * No thread-limit check is applied here. The reconnecting runner is already counted in
     * {@code consumerRunnerMap}, so checking the limit would make reconnects fail forever once exactly
     * {@code maxThread} runners exist.
     *
     * @param queueName
     *            queue to consume from
     * @param consumerId
     *            id to register the new consumer under
     * @return {@code consumerId} on success, or {@code null} if the consumer could not be created
     */
    private synchronized String initConsumerWithId(String queueName, String consumerId) {
        MessageConsumer messageConsumer = createMessageConsumer(queueName);
        if (messageConsumer == null) {
            return null;
        }
        consumerMap.put(consumerId, messageConsumer);
        logger.info(Log.op(LogOp.MQ_CONSUMER_INIT).kv("queue", queueName).toString());
        return consumerId;
    }

    /**
     * Starts one thread that consumes messages from a queue. Every call starts a new thread.
     *
     * @param queueName
     *            queue name
     * @param handler
     *            handler that processes each message
     * @throws MQException
     *             if the consumer could not be created (including when the thread limit is reached)
     */
    public void consumeMessage(String queueName, MessageHandler handler) {
        String consumerId = initConsumer(queueName);
        logger.info(Log.op(LogOp.MQ_CONSUME_MSG).kv("queue", queueName).kv("consumerId", consumerId).toString());
        if (consumerId == null) {
            throw new MQException("init initConsumer fail, queueName:" + queueName);
        }
        MessageConsumer messageConsumer = consumerMap.get(consumerId);
        if (messageConsumer == null) {
            throw new MQException("init initConsumer fail, queueName:" + queueName);
        }
        startMQConsumerThread(consumerId, queueName, handler, messageConsumer);
    }

    /**
     * Starts one thread that consumes messages from a queue using a trace-aware handler. Every call starts a
     * new thread.
     *
     * @param queueName
     *            queue name
     * @param handler
     *            handler that processes each message
     * @throws MQException
     *             if the consumer could not be created (including when the thread limit is reached)
     */
    public void consumeMessageWithTraceId(String queueName, MessageWithTraceHandler handler) {
        String consumerId = initConsumer(queueName);
        logger.info(Log.op(LogOp.MQ_CONSUME_MSG).kv("queue", queueName).kv("consumerId", consumerId).toString());
        if (consumerId == null) {
            throw new MQException("init initConsumer fail, queueName:" + queueName);
        }
        // consumerMap is keyed by consumerId, not by queue name.
        MessageConsumer messageConsumer = consumerMap.get(consumerId);
        if (messageConsumer == null) {
            throw new MQException("init initConsumer fail, queueName:" + queueName);
        }
        startMQConsumerThread(consumerId, queueName, handler, messageConsumer);
    }

    /**
     * Registers a runner for {@code consumerId} and starts its thread.
     *
     * @param consumerId
     *            id of the consumer the runner owns
     * @param queueName
     *            queue being consumed
     * @param handler
     *            message handler
     * @param messageConsumer
     *            the already-opened consumer
     */
    private void startMQConsumerThread(String consumerId, String queueName, MessageHandler handler,
            MessageConsumer messageConsumer) {
        MQConsumerRunner runner = new MQConsumerRunner(consumerId, queueName, messageConsumer, handler);
        consumerRunnerMap.put(consumerId, runner);
        // Name the thread at construction so it is identifiable even before run() executes.
        Thread thread = new Thread(runner, runner.currentThreadName());
        thread.start();
        logger.info(
                Log.op(LogOp.MQ_CONSUME_THREAD).kv("consumerId", consumerId).kv("threadName", runner.currentThreadName()).toString());
    }

    /**
     * Stops all runners, closes all consumers and then the shared session factory, and clears the
     * bookkeeping maps.
     */
    public synchronized void destory() {
        logger.info(Log.op(LogOp.MQ_CONSUME_SHUTDOWN).msg("MQConsumer Factory destory start").toString());
        // Stop every runner first. This includes runners that are mid-reconnect, which are present in
        // consumerRunnerMap but not in consumerMap. Doing it first means no runner tries to reconnect once
        // its consumer or the factory has been closed.
        if (consumerRunnerMap != null) {
            for (MQConsumerRunner runner : consumerRunnerMap.values()) {
                runner.shutdown();
            }
        }
        // Close the consumers that are currently open.
        if (consumerMap != null) {
            for (Map.Entry<String, MessageConsumer> entry : consumerMap.entrySet()) {
                entry.getValue().shutdown();
                logger.info(Log.op(LogOp.MQ_CONSUME_SHUTDOWN).msg("shutdown done").kv("consumerId", entry.getKey())
                        .toString());
            }
        }
        // Close the shared factory last.
        if (sessionFactory != null) {
            sessionFactory.shutdown();
        }
        if (consumerRunnerMap != null) {
            consumerRunnerMap.clear();
        }
        if (consumerMap != null) {
            consumerMap.clear();
        }
        if (counterMap != null) {
            counterMap.clear();
        }
        logger.info(Log.op(LogOp.MQ_CONSUME_SHUTDOWN).msg("MQConsumer Factory destory done").toString());
    }

    /**
     * Body of one consumer thread. It repeatedly calls {@code consumeMessage} on its consumer and reconnects
     * on failure, until {@link #shutdown()} is called.
     *
     * @author zhengxiaohong
     */
    class MQConsumerRunner implements Runnable {
        String consumerId;
        String queueName;
        MessageConsumer messageConsumer;
        MessageHandler handler;
        /** Cleared by {@link #shutdown()}; volatile so the consumer thread observes the change. */
        private volatile boolean running = true;

        public MQConsumerRunner(String consumerId, String queueName, MessageConsumer messageConsumer,
                MessageHandler handler) {
            this.consumerId = consumerId;
            this.queueName = queueName;
            this.messageConsumer = messageConsumer;
            this.handler = handler;
        }

        /**
         * @return the name given to the thread running this runner
         */
        protected String currentThreadName() {
            return "framework-mq-consumer-" + consumerId;
        }

        @Override
        public void run() {
            Thread.currentThread().setName(currentThreadName());
            while (running) {
                try {
                    messageConsumer.consumeMessage(handler);
                } catch (ShutdownSignalException ex) {
                    logger.warn(Log.op(LogOp.MQ_CONSUME_FAIL).toString(), ex);
                    messageConsumer = reConnect(consumerId, queueName, messageConsumer, handler);
                } catch (MessageClientException ex) {
                    logger.warn(Log.op(LogOp.MQ_CONSUME_FAIL).toString(), ex);
                    messageConsumer = reConnect(consumerId, queueName, messageConsumer, handler);
                }
            }
            // The last consumer is closed by destory(), not here.
            logger.info(Log.op(LogOp.MQ_CONSUME_HOOK).msg("thread exit").kv("consumerId", consumerId).toString());
        }

        /**
         * Closes {@code oldConsumer} and keeps trying to open a replacement under the same id, sleeping
         * {@code RE_CONN_INTERVAL} ms between attempts, until one succeeds or the runner is shut down.
         *
         * @return the new consumer, or {@code oldConsumer} if the runner was shut down before a replacement
         *         was obtained (the caller's loop then exits)
         */
        private MessageConsumer reConnect(String consumerId, String queueName, MessageConsumer oldConsumer,
                MessageHandler handler) {
            // If the process is shutting down, there is no need to rebuild the connection.
            if (!running) {
                return oldConsumer;
            }
            consumerMap.remove(consumerId);
            // Close the old connection.
            oldConsumer.shutdown();
            // TODO: the retry interval should grow as the number of reconnect attempts increases.
            // The loop re-checks `running` so that destory() can stop a runner that is retrying.
            while (running) {
                logger.warn(Log.op(LogOp.MQ_CONSUME_RE_CONN).msg("try to reConnect").kv("consumerId", consumerId)
                        .toString());
                initConsumerWithId(queueName, consumerId);
                MessageConsumer fresh = consumerMap.get(consumerId);
                if (fresh != null) {
                    logger.warn(Log.op(LogOp.MQ_CONSUME_RE_CONN).msg("reConnect suc").kv("consumerId", consumerId)
                            .toString());
                    return fresh;
                }
                CoreUtils.sleep(RE_CONN_INTERVAL);
            }
            return oldConsumer;
        }

        private void shutdown() {
            this.running = false;
        }
    }

}
